Zbiór danych: https://archive.ics.uci.edu/ml/datasets/Polish+companies+bankruptcy+data
import os
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import numpy as np
import pandas as pd
import dalex as dx
import shap
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve, auc, f1_score, classification_report, ConfusionMatrixDisplay
from xgboost import XGBClassifier
import joblib
plt.rcParams['figure.figsize'] = (16,6)
sns.set(font_scale = 1.2)
sns.set_style("ticks",{'axes.grid' : True})
sns.set_palette("deep")
## helper procedures
# confusion matrix
def plot_confusion_matrix(y, y_pred, title, ax):
    """Draw the confusion matrix of `y` vs. `y_pred` onto the axes `ax`, titled `title`."""
    display = ConfusionMatrixDisplay(
        confusion_matrix=confusion_matrix(y, y_pred),
        display_labels=[0, 1],
    )
    display.plot(ax=ax, cmap=plt.cm.Blues, values_format='d', colorbar=False)
    ax.set_title(title)
    ax.grid(False)
def print_precision_recall_report(clf, X_train, y_train, X_test, y_test, X_val, y_val):
    """Print precision/recall/f1 tables and draw confusion matrices for the
    train / test / validation splits of the fitted classifier `clf`."""

    # precision / recall / f1 tables
    def show_table(y_true, y_hat, title='Zbiór ...'):
        # Drop the 'support' row and the macro/weighted-average columns.
        report = pd.DataFrame(classification_report(
            y_true, y_hat,
            output_dict=True,
        ))
        text = report.iloc[:-1, :-2].to_string(float_format=lambda v: "{:7.4f}".format(v))
        # Centre the title over the table; the first '\n' index is the table width.
        print(title.center(text.index('\n'), '-'))
        print(text)
        print()

    # Predict each split once and reuse the predictions for tables and plots.
    predictions = [(y_true, clf.predict(X)) for X, y_true in
                   ((X_train, y_train), (X_test, y_test), (X_val, y_val))]

    table_titles = ('Zbiór treningowy', 'Zbiór testowy', 'Zbiór walidacyjny')
    for (y_true, y_hat), title in zip(predictions, table_titles):
        show_table(y_true, y_hat, title)

    fig, axs = plt.subplots(1, 3, figsize=(9, 3))
    plot_titles = ('Training set', 'Test set', 'Validation set')
    for (y_true, y_hat), ax, title in zip(predictions, axs, plot_titles):
        plot_confusion_matrix(y_true, y_hat, title, ax)
    axs[1].set_ylabel(None)
    axs[2].set_ylabel(None)
    plt.savefig("images/xgboost_matrix.png", dpi=200, bbox_inches='tight')
    plt.show()
# ROC / AUC
def plot_roc_auc(clf, X_train, y_train, X_test, y_test, X_val, y_val):
    """Plot ROC curves (AUC in the legend) for the train, test and validation sets."""

    def roc_points(X, y):
        # The positive-class probability is the score driving the ROC curve.
        scores = clf.predict_proba(X)[:, 1]
        fpr, tpr, _ = roc_curve(y, scores)
        return fpr, tpr, auc(fpr, tpr)

    splits = (
        (X_train, y_train, "steelblue", "Train ROC curve (area = %0.2f)"),
        (X_test, y_test, "darkorange", "Test ROC curve (area = %0.2f)"),
        (X_val, y_val, "darkgreen", "Val ROC curve (area = %0.2f)"),
    )
    plt.figure(figsize=(6, 6))
    for X, y, color, label_fmt in splits:
        fpr, tpr, area = roc_points(X, y)
        plt.plot(fpr, tpr, color=color, lw=2, label=label_fmt % area)
    # Diagonal reference line: a random classifier.
    plt.plot([0, 1], [0, 1], color="black", lw=2, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC curve")
    plt.legend(loc="lower right")
    plt.savefig("images/xgboost_roc.png", dpi=200, bbox_inches='tight')
    plt.show()
# Load the Polish companies bankruptcy dataset (see the UCI link at the top of the file).
df = pd.read_csv('dataset.csv')
df.head()
| Attr1 | Attr2 | Attr3 | Attr4 | Attr5 | Attr6 | Attr7 | Attr8 | Attr9 | Attr10 | ... | Attr58 | Attr59 | Attr60 | Attr61 | Attr62 | Attr63 | Attr64 | class | bankruptcy_after | year | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.200550 | 0.37951 | 0.39641 | 2.0472 | 32.3510 | 0.38825 | 0.249760 | 1.33050 | 1.1389 | 0.50494 | ... | 0.87804 | 0.001924 | 8.4160 | 5.1372 | 82.658 | 4.4158 | 7.4277 | 0 | 0 | 1 |
| 1 | 0.209120 | 0.49988 | 0.47225 | 1.9447 | 14.7860 | 0.00000 | 0.258340 | 0.99601 | 1.6996 | 0.49788 | ... | 0.85300 | 0.000000 | 4.1486 | 3.2732 | 107.350 | 3.4000 | 60.9870 | 0 | 0 | 1 |
| 2 | 0.248660 | 0.69592 | 0.26713 | 1.5548 | -1.1523 | 0.00000 | 0.309060 | 0.43695 | 1.3090 | 0.30408 | ... | 0.76599 | 0.694840 | 4.9909 | 3.9510 | 134.270 | 2.7185 | 5.2078 | 0 | 0 | 1 |
| 3 | 0.081483 | 0.30734 | 0.45879 | 2.4928 | 51.9520 | 0.14988 | 0.092704 | 1.86610 | 1.0571 | 0.57353 | ... | 0.94598 | 0.000000 | 4.5746 | 3.6147 | 86.435 | 4.2228 | 5.5497 | 0 | 0 | 1 |
| 4 | 0.187320 | 0.61323 | 0.22960 | 1.4063 | -7.3128 | 0.18732 | 0.187320 | 0.63070 | 1.1559 | 0.38677 | ... | 0.86515 | 0.124440 | 6.3985 | 4.3158 | 127.210 | 2.8692 | 7.8980 | 0 | 0 | 1 |
5 rows × 67 columns
# Targets: `class` (bankruptcy label) and `bankruptcy_after` (secondary target).
y = df['class']
y2 = df['bankruptcy_after']
X = df.drop(['class', 'bankruptcy_after', 'year'], axis = 1)
# Remove columns with many missing values.
X = X.drop(["Attr27","Attr21","Attr37"], axis = 1)
colnames = X.columns
# Remove rows that have more than 7 missing values.
# Compute the row mask once — the original recomputed `X.isna().sum(axis=1)`
# separately for y, y2 and X; the mask must come from X *before* filtering it.
row_ok = X.isna().sum(axis=1) <= 7
y = y[row_ok]
y2 = y2[row_ok]
X = X[row_ok]
# przycinanie outlierów + dodanie kolumny informującej o ich liczbie
# Clip outliers to the [2.5%, 97.5%] quantile range learned in fit(),
# and append a column counting how many features were out of range per row.
class OutlierCutter(BaseEstimator, TransformerMixin):
    def __init__(self):
        # Per-column quantile bounds, learned in fit().
        self._Q025 = None
        self._Q975 = None

    def fit(self, X, y=None):
        """Learn per-column 2.5% and 97.5% quantiles from X."""
        X = pd.DataFrame(X)
        self._Q025 = X.quantile(0.025)
        self._Q975 = X.quantile(0.975)
        return self

    def transform(self, X):
        """Return X clipped to the learned bounds, plus an 'outliers_count' column."""
        tmp = pd.DataFrame(X.copy())
        # Count out-of-range cells per row *before* clipping.
        tmp['outliers_count'] = ((tmp < self._Q025) | (tmp > self._Q975)).sum(axis=1)
        # Vectorized clip replaces the original per-column loop; the Series
        # bounds align with the feature columns (axis=1), 'outliers_count' excluded.
        feature_cols = tmp.columns[:-1]
        tmp[feature_cols] = tmp[feature_cols].clip(lower=self._Q025, upper=self._Q975, axis=1)
        return tmp
# Preprocessing pipeline: impute missing values -> clip outliers -> standardize.
preprocessing = make_pipeline(
SimpleImputer(missing_values=np.nan, strategy='median'), # replace missing values with the column median
OutlierCutter(),
StandardScaler()
)
# Stratified 60/20/20 split: hold out 20% for validation, then 25% of the rest as test.
X_train_test, X_val, y_train_test, y_val = train_test_split(X, y, random_state=2137, test_size=0.2, stratify=y)
X_train, X_test, y_train, y_test = train_test_split(X_train_test, y_train_test, random_state=420, test_size=0.25, stratify=y_train_test)
print(f'Zbiór treningowy: {X_train.shape}')
print(f'Zbiór testowy: {X_test.shape}')
# Bug fix: the original printed the label 'Zbiór testowy' again for the validation set.
print(f'Zbiór walidacyjny: {X_val.shape}')
Zbiór treningowy: (25896, 61) Zbiór testowy: (8632, 61) Zbiór testowy: (8632, 61)
# Fit preprocessing on the training split only, then apply it to test/validation
# so no information from those splits leaks into the imputer/quantiles/scaler.
X_train = preprocessing.fit_transform(X_train, y_train)
X_test = preprocessing.transform(X_test)
X_val = preprocessing.transform(X_val)
# Bundle the splits in the order the reporting helpers expect.
datasets = (X_train, y_train, X_test, y_test, X_val, y_val)
X_train.shape
(25896, 62)
Wykorzystany model to XGBClassifier. Po optymalizacji hiperparametrów uzyskaliśmy f1-score:
Model zdaje się być przeuczony, jednak wszystkie próby naprawy tego (zmniejszanie max_depth i learning_rate, zwiększanie min_child_weight, gamma, lambda) znacząco osłabiały wynik modelu na zbiorze testowym.
# XGBoost classifier; scale_pos_weight compensates for the heavy class imbalance.
clf = XGBClassifier(
    n_estimators = 1000,
    max_depth = 6,
    min_child_weight = 4,
    learning_rate = 0.1,
    gamma = 0.5,
    scale_pos_weight = (y_train == 0).sum() / (y_train == 1).sum(),
    objective = 'binary:logistic',
    eval_metric = 'logloss',
    random_state = 0,
    use_label_encoder=False,
    # `reg_lambda` is the sklearn-API alias for the booster parameter 'lambda'
    # (L2 regularization); it replaces the `**{'lambda': 1.2}` workaround,
    # which was needed only because `lambda` is a Python keyword.
    reg_lambda = 1.2
)
# Training is skipped here; the previously fitted model is loaded from disk.
# clf.fit(X_train, y_train)
# joblib.dump(clf, 'search/final_model.pkl');
clf = joblib.load('search/final_model.pkl')
print_precision_recall_report(clf, *datasets)
----------Zbiór treningowy---------
0 1 accuracy
precision 1.0000 0.9881 0.9994
recall 0.9994 1.0000 0.9994
f1-score 0.9997 0.9940 0.9994
-----------Zbiór testowy-----------
0 1 accuracy
precision 0.9752 0.7247 0.9669
recall 0.9904 0.5012 0.9669
f1-score 0.9827 0.5926 0.9669
---------Zbiór walidacyjny---------
0 1 accuracy
precision 0.9735 0.6576 0.9627
recall 0.9877 0.4675 0.9627
f1-score 0.9805 0.5465 0.9627
# ROC curves for all three splits (also saved to images/xgboost_roc.png).
plot_roc_auc(clf, *datasets)
Do interpretacji będziemy korzystać ze zbioru walidacyjnego.
# columns = ["net profit / total assets", "total liabilities / total assets", "working capital / total assets", "current assets / short-term liabilities", "((cash + short-term securities + receivables - short-term liabilities) / (operating expenses - depreciation)) * 365", "retained earnings / total assets", "EBIT / total assets", "book value of equity / total liabilities", "sales / total assets", "equity / total assets", "(gross profit + extraordinary items + financial expenses) / total assets", "gross profit / short-term liabilities", "(gross profit + depreciation) / sales", "(gross profit + interest) / total assets", "(total liabilities * 365) / (gross profit + depreciation)", "(gross profit + depreciation) / total liabilities", "total assets / total liabilities", "gross profit / total assets", "gross profit / sales", "(inventory * 365) / sales", "profit on operating activities / total assets", "net profit / sales", "gross profit (in 3 years) / total assets", "(equity - share capital) / total assets", "(net profit + depreciation) / total liabilities", "working capital / fixed assets", "logarithm of total assets", "(total liabilities - cash) / sales", "(gross profit + interest) / sales", "(current liabilities * 365) / cost of products sold", "operating expenses / short-term liabilities", "operating expenses / total liabilities", "profit on sales / total assets", "total sales / total assets", "constant capital / total assets", "profit on sales / sales", "(current assets - inventory - receivables) / short-term liabilities", "total liabilities / ((profit on operating activities + depreciation) * (12/365))", "profit on operating activities / sales", "rotation receivables + inventory turnover in days", "(receivables * 365) / sales", "net profit / inventory", "(current assets - inventory) / short-term liabilities", "(inventory * 365) / cost of products sold", "EBITDA (profit on operating activities - depreciation) / total assets", "EBITDA (profit on operating 
activities - depreciation) / sales", "current assets / total liabilities", "short-term liabilities / total assets", "(short-term liabilities * 365) / cost of products sold)", "equity / fixed assets", "constant capital / fixed assets", "working capital", "(sales - cost of products sold) / sales", "(current assets - inventory - short-term liabilities) / (sales - gross profit - depreciation)", "total costs /total sales", "long-term liabilities / equity", "sales / inventory", "sales / receivables", "(short-term liabilities *365) / sales", "sales / short-term liabilities", "sales / fixed assets", "outliers_count"]
columns = ["net profit / total assets", "total liabilities / total assets", "working capital / total assets", "current assets / short-term liabilities", "((c + sts + r - stl) / (oe - d)) * 365", "retained earnings / total assets", "EBIT / total assets", "book value of equity / total liabilities", "sales / total assets", "equity / total assets", "(gross profit + extraordinary items + financial expenses) / total assets", "gross profit / short-term liabilities", "(gross profit + depreciation) / sales", "(gross profit + interest) / total assets", "(total liabilities * 365) / (gross profit + depreciation)", "(gross profit + depreciation) / total liabilities", "total assets / total liabilities", "gross profit / total assets", "gross profit / sales", "(inventory * 365) / sales", "profit on operating activities / total assets", "net profit / sales", "gross profit (in 3 years) / total assets", "(equity - share capital) / total assets", "(net profit + depreciation) / total liabilities", "working capital / fixed assets", "logarithm of total assets", "(total liabilities - cash) / sales", "(gross profit + interest) / sales", "(current liabilities * 365) / cost of products sold", "operating expenses / short-term liabilities", "operating expenses / total liabilities", "profit on sales / total assets", "total sales / total assets", "constant capital / total assets", "profit on sales / sales", "(current assets - inventory - receivables) / short-term liabilities", "total liabilities / ((profit on operating activities + depreciation) * (12/365))", "profit on operating activities / sales", "rotation receivables + inventory turnover in days", "(receivables * 365) / sales", "net profit / inventory", "(current assets - inventory) / short-term liabilities", "(inventory * 365) / cost of products sold", "EBITDA (profit on operating activities - depreciation) / total assets", "EBITDA (profit on operating activities - depreciation) / sales", "current assets / total liabilities", "short-term 
liabilities / total assets", "(short-term liabilities * 365) / cost of products sold)", "equity / fixed assets", "constant capital / fixed assets", "working capital", "(sales - cost of products sold) / sales", "(current assets - inventory - short-term liabilities) / (sales - gross profit - depreciation)", "total costs /total sales", "long-term liabilities / equity", "sales / inventory", "sales / receivables", "(short-term liabilities *365) / sales", "sales / short-term liabilities", "sales / fixed assets", "outliers_count"]
# Attach human-readable feature names and realign indices after the split,
# so positional and label-based access agree for the explainers below.
X_val = pd.DataFrame(X_val, columns=columns).reset_index(drop=True)
y_val = y_val.reset_index(drop=True)
X_val
| net profit / total assets | total liabilities / total assets | working capital / total assets | current assets / short-term liabilities | ((c + sts + r - stl) / (oe - d)) * 365 | retained earnings / total assets | EBIT / total assets | book value of equity / total liabilities | sales / total assets | equity / total assets | ... | (sales - cost of products sold) / sales | (current assets - inventory - short-term liabilities) / (sales - gross profit - depreciation) | total costs /total sales | long-term liabilities / equity | sales / inventory | sales / receivables | (short-term liabilities *365) / sales | sales / short-term liabilities | sales / fixed assets | outliers_count | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.320063 | -0.272069 | -0.348343 | -0.459759 | -0.016427 | 1.566265 | 1.122777 | -0.282882 | -0.485290 | 0.331947 | ... | 0.339413 | 0.763840 | -0.363210 | -0.289805 | -0.037083 | -0.265588 | -0.374576 | -0.225405 | -0.351983 | -0.495583 |
| 1 | 0.157050 | -0.203409 | 1.112662 | -0.128322 | 0.435468 | -0.117865 | 0.186936 | -0.311755 | 0.262703 | 0.264897 | ... | -0.146862 | 0.010615 | 0.147130 | -0.452619 | 0.100671 | -0.738400 | -0.165885 | -0.416733 | 1.863299 | -0.495583 |
| 2 | 1.155339 | -1.125665 | -0.014975 | -0.145957 | 0.326008 | -0.117865 | 1.324930 | 0.643161 | -0.885578 | 1.189662 | ... | 1.891868 | 0.336973 | -2.046946 | -0.452619 | 4.940445 | -0.508789 | -0.113290 | -0.452440 | -0.458584 | 0.824596 |
| 3 | -0.447490 | -0.519100 | -0.126166 | -0.297405 | 0.190822 | -0.117865 | -0.445440 | -0.149905 | -0.368713 | 0.581447 | ... | -0.603708 | -0.402519 | 0.470366 | -0.452619 | 0.492116 | -0.565500 | -0.365941 | -0.235430 | -0.421120 | -0.495583 |
| 4 | 1.402071 | -0.457333 | 0.073566 | -0.168415 | 0.291545 | -0.225868 | 1.622942 | -0.187995 | 0.261389 | 0.519271 | ... | 2.585966 | 0.713207 | -2.619452 | -0.127971 | 0.290779 | -0.173867 | -0.673085 | 0.358155 | -0.369342 | -0.165539 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 8627 | -0.336797 | 0.919175 | -0.499587 | -0.441619 | -0.098342 | -0.117865 | -0.352091 | -0.573689 | -0.510258 | -0.860739 | ... | -0.453497 | -0.216065 | 0.379949 | 2.743800 | -0.290862 | -0.071408 | -0.467983 | -0.100153 | -0.446381 | -0.495583 |
| 8628 | 1.082583 | -1.042287 | 2.001398 | 0.900720 | 3.557190 | 1.889775 | 1.236265 | 0.449998 | 3.377097 | 1.106057 | ... | -0.210896 | 0.324720 | 0.206801 | -0.452619 | 2.184442 | -0.204469 | -0.977266 | 3.445843 | 4.721358 | 0.989618 |
| 8629 | -0.461711 | -0.083874 | -0.593727 | -0.509498 | -0.295706 | -0.117865 | -0.512115 | -0.356731 | -0.785398 | 0.145038 | ... | -0.279979 | -0.412081 | 0.526531 | 0.392427 | -0.393317 | -0.121833 | -0.044479 | -0.493902 | -0.460802 | -0.495583 |
| 8630 | 0.182973 | -1.026501 | 0.721249 | 0.232303 | 0.231704 | -0.117865 | 0.266268 | 0.418712 | 0.406220 | 1.090228 | ... | -0.160754 | -0.101500 | -0.068899 | -0.452619 | -0.156703 | -0.305329 | -0.743170 | 0.628756 | -0.296584 | -0.495583 |
| 8631 | -0.328639 | -0.634498 | -0.728240 | -0.571268 | -0.332952 | -0.117865 | -0.390013 | -0.067272 | -0.975592 | 0.697159 | ... | 0.412475 | -0.336374 | -0.343117 | -0.452619 | -0.365866 | -0.685934 | 0.914965 | -0.780004 | -0.468939 | -0.330561 |
8632 rows × 62 columns
# SHAP values for the validation set via the tree-model-specific explainer.
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(X_val)
ntree_limit is deprecated, use `iteration_range` or model slicing instead.
# SHAP summary (beeswarm) plot of feature impacts on the validation set.
shap.summary_plot(shap_values, features=X_val, feature_names=X_val.columns)
Kolumny na wykresie są uporządkowane od tych co mają największy wpływ do tych z najmniejszym. Rzuca się w oczy często powtarzający się mianownik total assets.
# Force plot for the first validation observation (initjs enables the notebook JS renderer).
shap.initjs()
shap.force_plot(explainer.expected_value, shap_values[0,:], X_val.iloc[0,:])
# NOTE(review): dalex is already imported at the top of the file; this re-import is redundant.
import dalex as dx
explainer2 = dx.Explainer(clf, data=X_val, y=y_val)
Preparation of a new explainer is initiated -> data : 8632 rows 62 cols -> target variable : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray. -> target variable : 8632 values -> model_class : xgboost.sklearn.XGBClassifier (default) -> label : Not specified, model's class short name will be used. (default) -> predict function : <function yhat_proba_default at 0x00000225FD3FA5E0> will be used (default) -> predict function : Accepts pandas.DataFrame and numpy.ndarray. -> predicted values : min = 1.74e-07, mean = 0.0482, max = 1.0 -> model type : classification will be used (default) -> residual function : difference between y and yhat (default) -> residuals : min = -0.988, mean = -0.000111, max = 1.0 -> model_info : package xgboost A new explainer has been created!
Wybieramy trzy obserwacje: A i B, gdzie firma nie zbankrutowała, oraz C, gdzie firma zbankrutowała.
# Three observations for local explanations: A and B did not go bankrupt, C did.
survivors = X_val[y_val == 0]
bankrupts = X_val[y_val == 1]
A = survivors.iloc[[0]]
B = survivors.iloc[[1]]
C = bankrupts.iloc[[0]]
# Break-down and SHAP decompositions for each observation, in the same order
# as the original cells: A, then B, then C.
for observation, obs_label in ((A, 'A'), (B, 'B'), (C, 'C')):
    explainer2.predict_parts(observation, type="break_down", label=obs_label).plot()
    explainer2.predict_parts(observation, type="shap", B=10, label=obs_label).plot()
Na powyższych wykresach najczęściej powtarza się zmienna current assets - inventory / short-term liabilities i to ona popycha predykcję we właściwym kierunku. Jednak warto zauważyć, że składnik „all other factors” (wszystkie pozostałe zmienne) zwykle stanowi całkiem potężny czynnik. To zapewne znaczy, że wiele zmiennych jest ważnych i ma wpływ na predykcję.
Spójrzmy na variable importance oraz wykresy ceteris paribus.
# Permutation-based variable importance.
vi = explainer2.model_parts(random_state=0)
# Keep 10 variables; iloc[1:11] skips the first (highest-dropout-loss) row —
# presumably the '_full_model_' / baseline entry in the dalex result. TODO confirm.
variables = vi.result.sort_values(by='dropout_loss', ascending=False).iloc[1:11]['variable'].to_numpy()
vi.plot(show=False)
Mocno rzuca się w oczy bardzo znikomy wpływ poszczególnych zmiennych - jedna zmienna chyba nie ma mocy całkowicie zmienić predykcji.
# Partial-dependence profiles for the 10 variables selected above.
pdp = explainer2.model_profile(type='partial', variables=variables)
pdp.plot()
Calculating ceteris paribus: 100%|█████████████████████████████████████████████████████| 10/10 [00:00<00:00, 10.23it/s]
Choć większość wykresów przedstawia proste linie, to jednak w kilku przypadkach widzimy logiczną zależność, a mianowicie: im niższe
current assets - inventory / short-term liabilities oraz sales / total assets,
lub im wyższe gross profit (in 3 years) / total assets, tym model jest bardziej skłonny uznać, że firma zbankrutuje.